import torch
import torch.nn as nn


class RewardModel(nn.Module):
    """MLP that scores an (image embedding, text embedding) pair.

    The two embeddings are concatenated along the last dimension and passed
    through a 3-layer MLP (ReLU + dropout between layers) ending in a sigmoid,
    so each output is a single score in [0, 1].
    """

    def __init__(self, embed_dim=768, *, hidden_dim=1024, dropout=0.5):
        """Build the scoring MLP.

        Args:
            embed_dim: Size of the last dimension of EACH input embedding.
                The default 768 is the ViT-L/14 embedding size (per the
                original author's note).
            hidden_dim: Width of the two hidden layers (default 1024, the
                previously hard-coded value).
            dropout: Dropout probability applied after each hidden layer
                (default 0.5, the previously hard-coded value).
        """
        super().__init__()
        # NOTE: the layer order below fixes the state_dict keys
        # (fc.0 / fc.3 / fc.6 for the Linear layers); keep it stable so that
        # previously saved checkpoints still load.
        self.fc = nn.Sequential(
            # Input is the image and text embeddings concatenated together.
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, image_embed, text_embed):
        """Score an image/text embedding pair.

        Args:
            image_embed: Tensor whose last dimension is ``embed_dim``.
            text_embed: Tensor whose last dimension is ``embed_dim``; its
                leading dimensions must match ``image_embed`` for
                ``torch.cat``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of size 1, holding sigmoid scores in [0, 1].
        """
        combined = torch.cat((image_embed, text_embed), dim=-1)
        return self.fc(combined)
